# -*- coding: utf-8 -*-
"""
Created on Thu Mar 30 17:14:01 2023

@author: lv
"""

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from utils.UserIntentAttention import UserIntentAttention

from transformers import BertTokenizer

# Chinese BERT tokenizer; its vocab size is passed to the model below.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# Fixed length every encoded sequence is padded to before batching.
max_seq_length = 50

# Initialize the attention model under test.
# NOTE(review): the original comment called this "SelfAttentionEncoder",
# but the class is UserIntentAttention — presumably an earlier name.
model = UserIntentAttention(max_seq_length=max_seq_length,
                             output_num_sequences=20,
                             input_max_num_sequences=30,
                             tokenizer_size=tokenizer.vocab_size)

def _encode_and_pad(texts):
    """Tokenize each text and return a (len(texts), max_seq_length) LongTensor.

    Each text is encoded with special tokens, truncated to max_seq_length,
    then right-padded with 0 (the BERT [PAD] id — TODO confirm for this
    tokenizer) so every row has the same length.
    """
    ids = []
    for text in texts:
        encoded = tokenizer.encode(text, add_special_tokens=True)
        # BUG FIX: truncate before padding. The original computed
        # [0] * (max_seq_length - len(encoded)), which is empty for any
        # text longer than max_seq_length, leaving a ragged list that
        # makes torch.tensor() raise.
        encoded = encoded[:max_seq_length]
        encoded += [0] * (max_seq_length - len(encoded))
        ids.append(encoded)
    return torch.tensor(ids)


# Two example input sentences.
text1 = "你好，自然语言处理"
text2 = "这是一条测试SelfAttentionEncoder的语句"
input_texts = [text1, text2]

# Encode and pad the input sentences.
input_ids = _encode_and_pad(input_texts)


# Load the intent templates, one per line.
with open('../templates/intent/intent_template.txt', 'r', encoding='utf-8') as f:
    intent_templates = [line.strip() for line in f]

# Encode and pad the intent templates the same way.
intent_template_ids = _encode_and_pad(intent_templates)


# 对输入进行编码
output = model(input_ids,intent_template_ids)

# 对编码结果进行解码
decoded_texts = []
for i in range(output.shape[0]):
    decoded_text = tokenizer.decode(output[i], skip_special_tokens=True)
    decoded_texts.append(decoded_text)

# 输出解码结果
print(decoded_texts)